import math
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.functional import softmax

class DiscreteEmbeddingLayer(nn.Module):
    """Look up a learned embedding vector for each integer token id.

    Args:
        num_max_tokens: Number of rows in the embedding table (passed to
            ``nn.Embedding`` as ``num_embeddings``).
        embed_dim: Size of each embedding vector.
    """

    def __init__(self,
                 num_max_tokens: int = 1000,
                 embed_dim: int = 256
                 ):
        super().__init__()

        self.embed_layer = nn.Embedding(num_max_tokens, embed_dim)

        self.initialize_weights()

    def initialize_weights(self):
        """Run :meth:`_init_weights` on this module and every submodule."""
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise the parameters of a single module ``m`` by type.

        - ``nn.Linear``: Xavier-uniform weight, zero bias (when present).
        - ``nn.LayerNorm``: zero bias, unit weight (when present).
        - ``nn.Conv2d``: Xavier-uniform weight, with fans computed on the
          weight viewed as ``(out_channels, -1)``; the bias is left as is.
        - ``nn.Embedding``: normal weight with mean 0 and std 0.02.

        Any other module type is left untouched.
        """
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # A LayerNorm built without affine parameters (or without a
            # bias) exposes them as None; skip those instead of crashing.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            # The view shares storage with the weight, so the in-place
            # init below writes straight into the convolution kernel.
            w = m.weight.data
            nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, 0, 0.02)

    def forward(self, x):
        """Return the embeddings for the integer index tensor ``x``.

        The output has the shape of ``x`` with a trailing ``embed_dim``
        dimension appended.
        """
        return self.embed_layer(x)


class AttentionFusion(nn.Module):
    """Fuse a state and an action tensor with dot-product attention.

    The state is projected to queries; the action is projected to both
    keys and values. Attention scores are the plain (unscaled) dot
    products between queries and keys.

    Args:
        state_dim: Size of the last dimension of ``state``.
        action_dim: Size of the last dimension of ``action``.
        embed_dim: Size of the query/key/value projections and of the
            fused output's last dimension.
    """

    def __init__(self, state_dim, action_dim, embed_dim):
        super().__init__()
        self.query = nn.Linear(state_dim, embed_dim)
        self.key = nn.Linear(action_dim, embed_dim)
        self.value = nn.Linear(action_dim, embed_dim)

    def forward(self, state, action):
        """
        Adapts to input shapes:
        - 2D input: (batch_size, feature_dim)
        - 3D input: (batch_size, num_envs, feature_dim)

        ``state`` and ``action`` may each be 2D or 3D independently; a 2D
        input is treated as a 3D input whose middle dimension is 1. The
        output rank follows ``state``: (batch_size, embed_dim) for a 2D
        state, (batch_size, num_envs, embed_dim) for a 3D state.

        Note: when both inputs are 2D there is a single key per batch
        element, so the softmax weight is exactly 1 and the result equals
        ``self.value(action)`` regardless of ``state``.
        """
        # Promote each input on its own so mixed 2D/3D pairs work too.
        squeeze_output = state.dim() == 2
        if squeeze_output:
            state = state.unsqueeze(1)  # (batch_size, 1, feature_dim)
        if action.dim() == 2:
            action = action.unsqueeze(1)  # (batch_size, 1, feature_dim)

        # Generate Query, Key, and Value
        q = self.query(state)  # (batch_size, n_query, embed_dim)
        k = self.key(action)  # (batch_size, n_key, embed_dim)
        v = self.value(action)  # (batch_size, n_key, embed_dim)

        # Attention weights over the keys: (batch_size, n_query, n_key)
        scores = torch.bmm(q, k.transpose(1, 2))
        attention_weights = softmax(scores, dim=-1)

        # Weighted sum of values: (batch_size, n_query, embed_dim)
        fused = torch.bmm(attention_weights, v)

        # A 2D state yields a 2D result, mirroring the input rank.
        if squeeze_output:
            fused = fused.squeeze(1)  # (batch_size, embed_dim)

        return fused
